package com.hanxiaozhang.delayoperation.timewheel;

import java.util.concurrent.Delayed;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;


/**
 * 功能描述: <br>
 * 〈时间轮的一格，即时间格〉
 * <p>
 * 采用双向链表实现，用于封装位于特定时间区间范围内的所有的延时任务
 *
 * @Author:hanxinghua
 * @Date: 2024/2/26
 */
class TimerTaskList implements Delayed {

    /*
     TimerTaskList使用虚拟根目录的双向链表实现
     root.next 指向 头部
     root.prev 指向 尾部
     */

    /**
     * 根节点(空根节点)
     */
    private final TimerTaskEntry root = new TimerTaskEntry(null, -1);

    /**
     * 链表所在bucket的过期时间戳
     */
    private final AtomicLong expiration = new AtomicLong(-1L);

    private final AtomicInteger taskCounter;

    TimerTaskList(AtomicInteger taskCounter) {
        root.setNext(root);
        root.setPrev(root);
        this.taskCounter = taskCounter;
    }

    /**
     * 循环执行每一个任务
     *
     * @param f
     */
    void foreach(Consumer<TimerTask> f) {
        synchronized (this) {
            TimerTaskEntry entry = root.getNext();
            while (entry != root) {
                final TimerTaskEntry nextEntry = entry.getNext();

                if (!entry.cancelled()) {
                    f.accept(entry.getTimerTask());
                }

                entry = nextEntry;
            }
        }
    }

    /**
     * 将给定定时任务插入到链表
     *
     * @param timerTaskEntry
     */
    void add(TimerTaskEntry timerTaskEntry) {
        boolean done = false;
        while (!done) {

            // 在添加之前尝试移除该定时任务，保证该任务没有在其他链表中
            timerTaskEntry.remove();

            synchronized (this) {
                synchronized (timerTaskEntry) {
                    if (timerTaskEntry.getList() == null) {

                        // 从尾部插入
                        final TimerTaskEntry tail = root.getPrev();
                        timerTaskEntry.setNext(root);
                        timerTaskEntry.setPrev(tail);
                        // 更新list
                        timerTaskEntry.setList(this);
                        tail.setNext(timerTaskEntry);
                        root.setPrev(timerTaskEntry);
                        taskCounter.incrementAndGet();
                        done = true;
                    }
                }
            }
        }
    }

    /**
     * 从链表中移除定时任务
     *
     * @param timerTaskEntry
     */
    void remove(TimerTaskEntry timerTaskEntry) {
        synchronized (this) {
            synchronized (timerTaskEntry) {
                if (timerTaskEntry.getList() == this) {
                    // 当前节点的后继节点的前继 赋值 当前节点的前继
                    timerTaskEntry.getNext().setPrev(timerTaskEntry.getPrev());
                    // 当前节点的前继节点的后继 赋值 当前节点的后继
                    timerTaskEntry.getPrev().setNext(timerTaskEntry.getNext());
                    // 清空数据
                    timerTaskEntry.setNext(null);
                    timerTaskEntry.setPrev(null);
                    // 更新list为null
                    timerTaskEntry.setList(null);
                    taskCounter.decrementAndGet();
                }
            }
        }
    }

    /**
     * 清空链表中的所有元素，并执行节点中的逻辑
     * <p>
     * 该方法用于将高层次时间轮 Bucket 上的定时任务重新插入回低层次的 Bucket 中
     *
     * @param f
     */
    void flush(Consumer<TimerTaskEntry> f) {
        synchronized (this) {
            // 找到链表第一个元素
            TimerTaskEntry head = root.getNext();
            // 开始遍历链表
            while (head != root) {
                // 移除遍历到的链表元素
                remove(head);
                // 执行传入参数f的逻辑
                f.accept(head);
                head = root.getNext();
            }
            // 清空过期时间设置
            expiration.set(-1L);
        }
    }

    @Override
    public long getDelay(TimeUnit unit) {
        return unit.convert(
                Math.max(getExpiration() - TimeUnit.NANOSECONDS.toMillis(System.nanoTime()), 0),
                TimeUnit.MILLISECONDS);
    }

    @Override
    public int compareTo(Delayed o) {
        TimerTaskList other = (TimerTaskList) o;

        if (getExpiration() < other.getExpiration()) {
            return -1;
        } else if (getExpiration() > other.getExpiration()) {
            return 1;
        } else {
            return 0;
        }
    }

    /**
     * 获取 bucket的过期时间
     *
     * @return
     */
    Long getExpiration() {
        return expiration.get();
    }

    /**
     * 设置 bucket的过期时间
     *
     * @param expirationMs
     * @return 如果过期时间设置成，则返回true
     */
    boolean setExpiration(Long expirationMs) {
        return expiration.getAndSet(expirationMs) != expirationMs;
    }
}
